/*
 * Copyright 1999-2002 Carnegie Mellon University.  
 * Portions Copyright 2002 Sun Microsystems, Inc.  
 * Portions Copyright 2002 Mitsubishi Electric Research Laboratories.
 * Portions Copyright 2010 LIUM, University of Le Mans, France
 *  - Anthony Rousseau, Teva Merlin, Yannick Esteve
 *
 * All Rights Reserved.  Use is subject to license terms.
 * 
 * See the file "license.terms" for information on usage and
 * redistribution of this file, and for a DISCLAIMER OF ALL 
 * WARRANTIES.
 *
 */

package edu.cmu.sphinx.linguist.language.ngram.large;

import edu.cmu.sphinx.linguist.dictionary.Dictionary;
import edu.cmu.sphinx.util.LogMath;
import edu.cmu.sphinx.util.Utilities;

import java.io.*;
import java.util.regex.Pattern;
import java.util.regex.Matcher;

/**
 * Reads a binary NGram language model file ("DMP file") generated by the SphinxBase sphinx_lm_convert.
 * <p>
 * Note that all probabilities in the grammar are stored in LogMath log base format. Language 
 * Probabilities in the language model file are stored in log 10 base. They are converted to 
 * the LogMath base.
 */

public class BinaryLoader {

    /** Header string identifying a trigram DMP file. */
    private static final String DARPA_TG_HEADER = "Darpa Trigram LM";
    /** Header string identifying a quadrigram DMP file. */
    private static final String DARPA_QG_HEADER = "Darpa Quadrigram LM";
    
    // For convenience, the generic N-gram header is a regular expression ("\d" stands
    // for the order digit), so the pattern string is 2 characters longer than the
    // literal header it matches. Therefore header.length() must be adjusted by -1
    // (and not +1) when comparing lengths, and Pattern.matches() is used for
    // equality on header names.
    private static final String DARPA_NG_HEADER = "Darpa \\d-gram LM";
    
    /** Default log2 of the n-gram segment size, used when the file does not carry one. */
    private static final int LOG2_NGRAM_SEGMENT_SIZE = 9;
    
    /** Log10 probability assigned to the sentence start/end unigrams ("practically never"). */
    private static final float MIN_PROBABILITY = -99.0f;
    /** Upper bound accepted for a probability-table size read from the file. */
    private static final int MAX_PROB_TABLE_SIZE = java.lang.Integer.MAX_VALUE;

    private LogMath logMath;
    
    // Maximum n-gram order of the loaded model (3 for trigram, 4 for quadrigram, ...).
    private int maxNGram;
    
    private float unigramWeight;
    private float languageWeight;
    
    // Word insertion probability (linear domain).
    private double wip;
    
    // Byte order of the file; assumed big-endian until the header proves otherwise.
    private boolean bigEndian = true;
    private boolean applyLanguageWeightAndWip;

    // Running count of bytes consumed from the stream; used to record n-gram file offsets.
    private long bytesRead;

    private UnigramProbability[] unigrams;
    private String[] words;
    
    // Per-order model data, indexed by (order - 1).
    private long[] NGramOffset;
    private int[] numberNGrams;
    private int logNGramSegmentSize;
    
    // IDs of the sentence start/end words, or -1 if absent from the vocabulary.
    private int startWordID;
    private int endWordID;
    
    private int[][] NGramSegmentTable;
    private float[][] NGramProbTable;
    private float[][] NGramBackoffTable;

    // Random-access handle used by loadBuffer() after the layout has been parsed.
    private RandomAccessFile file;

    // Bytes multiplier for LM fields (2 = 16 bits, 4 = 32 bits)
    private int bytesPerField;
    
    /**
     * Initializes the binary loader
     *
     * @param location                  location of the model
     * @param format                    file format
     * @param applyLanguageWeightAndWip if true apply language weight and word insertion penalty
     * @param languageWeight            language weight
     * @param wip                       word insertion probability
     * @param unigramWeight             unigram weight
     * @throws IOException if an I/O error occurs
     */
    public BinaryLoader(File location, String format,
                        boolean applyLanguageWeightAndWip,
                        float languageWeight, double wip, float unigramWeight)
            throws IOException {
        this(format, applyLanguageWeightAndWip, languageWeight, wip, unigramWeight);
        loadModelLayout(new FileInputStream (location));
        file = new RandomAccessFile(location, "r");
    }


    /**
     * Initializes the binary loader without touching the file system; the model
     * layout must be loaded separately.
     *
     * @param format                    file format
     * @param applyLanguageWeightAndWip if true apply language weight and word insertion penalty
     * @param languageWeight            language weight
     * @param wip                       word insertion probability
     * @param unigramWeight             unigram weight
     */
    public BinaryLoader(String format, boolean applyLanguageWeightAndWip,
                        float languageWeight, double wip, float unigramWeight) {
        this.applyLanguageWeightAndWip = applyLanguageWeightAndWip;
        this.languageWeight = languageWeight;
        this.wip = wip;
        this.unigramWeight = unigramWeight;
        this.logMath = LogMath.getLogMath();
        // No sentence start/end words identified yet.
        this.startWordID = -1;
        this.endWordID = -1;
    }

    /** Releases the random-access file handle, if one was opened. */
    public void deallocate() throws IOException {
        if (file != null) {
            file.close();
        }
    }

    /**
     * Returns the number of unigrams. Convenience for {@code getNumberNGrams(1)}.
     *
     * @return the number of unigrams
     */
    public int getNumberUnigrams() {
        return getNumberNGrams(1);
    }


    /**
     * Returns the number of bigrams. Convenience for {@code getNumberNGrams(2)}.
     *
     * @return the number of bigrams
     */
    public int getNumberBigrams() {
        return getNumberNGrams(2);
    }


    /**
     * Returns the number of trigrams. Convenience for {@code getNumberNGrams(3)}.
     *
     * @return the number of trigrams
     */
    public int getNumberTrigrams() {
        return getNumberNGrams(3);
    }
    
    
    /**
     * Returns the number of NGrams at a specified N order.
     *
     * @param n			the desired order, in [1, maxNGram]
     * @return the number of NGrams
     */
    public int getNumberNGrams(int n) {
        // Be sure that we don't go beyond the model's order.
        // FIX: use logical && (was bitwise &) for consistency with the other accessors.
        assert (n <= maxNGram) && (n > 0);
        return numberNGrams[n - 1];
    }

    
    /**
     * Returns all the unigrams (including the trailing sentinel entry).
     *
     * @return all the unigrams
     */
    public UnigramProbability[] getUnigrams() {
        return unigrams;
    }


    /**
     * Returns all the bigram probabilities. Convenience for {@code getNGramProbabilities(2)}.
     *
     * @return all the bigram probabilities
     */
    public float[] getBigramProbabilities() {
        return getNGramProbabilities(2);
    }


    /**
     * Returns all the trigram probabilities. Convenience for {@code getNGramProbabilities(3)}.
     *
     * @return all the trigram probabilities
     */
    public float[] getTrigramProbabilities() {
        return getNGramProbabilities(3);
    }


    /**
     * Returns all the trigram backoff weights. Convenience for {@code getNGramBackoffWeights(3)}.
     *
     * @return all the trigram backoff weights
     */
    public float[] getTrigramBackoffWeights() {
        return getNGramBackoffWeights(3);
    }


    /**
     * Returns the trigram segment table. Convenience for {@code getNGramSegments(3)}.
     *
     * @return the trigram segment table
     */
    public int[] getTrigramSegments() {
        return getNGramSegments(3);
    }


    /**
     * Returns the log of the bigram segment size. Note that a single segment
     * size is shared by all orders, so this is the same value as
     * {@code getLogNGramSegmentSize()}.
     *
     * @return the log of the bigram segment size
     */
    public int getLogBigramSegmentSize() {
        return logNGramSegmentSize;
    }
    
    
    /**
     * Returns all the NGram probabilities at a specified N order.
     * Only defined for orders above 1; unigram probabilities live in
     * the {@code UnigramProbability} entries.
     *
     * @param n			the desired order, in [2, maxNGram]
     * @return all the NGram probabilities
     */
    public float[] getNGramProbabilities(int n) {
    	// Be sure that we don't overcome the model
    	assert (n <= maxNGram) && (n > 1);
        return NGramProbTable[n - 1];
    }

    
    /**
     * Returns all the NGram backoff weights at a specified N order.
     * Only defined for orders above 2; backoff tables are not stored
     * for unigrams and bigrams in this layout.
     *
     * @param n			the desired order, in [3, maxNGram]
     * @return all the NGram backoff weights
     */
    public float[] getNGramBackoffWeights(int n) {
        // Be sure that we don't go beyond the model's order.
        // FIX: use logical && (was bitwise &) for consistency with the other accessors.
        assert (n <= maxNGram) && (n > 2);
        return NGramBackoffTable[n - 1];
    }


    /**
     * Returns the NGram segment table at a specified order.
     * Only defined for orders above 2.
     *
     * @param n			the desired order, in [3, maxNGram]
     * @return the NGram segment table
     */
    public int[] getNGramSegments(int n) {
        // Be sure that we don't go beyond the model's order.
        // FIX: use logical && (was bitwise &) for consistency with the other accessors.
        assert (n <= maxNGram) && (n > 2);
        return NGramSegmentTable[n - 1];
    }


    /**
     * Returns the log of the NGram segment size (shared by all orders).
     *
     * @return the log of the NGram segment size
     */
    public int getLogNGramSegmentSize() {
        return logNGramSegmentSize;
    }


    /**
     * Returns all the words of the vocabulary, indexed by word ID.
     *
     * @return all the words
     */
    public String[] getWords() {
        return words;
    }


    /**
     * Returns the location (or offset) into the file where bigrams start.
     * Convenience for {@code getNGramOffset(2)}.
     *
     * @return the location of the bigrams
     */
    public long getBigramOffset() {
        return getNGramOffset(2);
    }


    /**
     * Returns the location (or offset) into the file where trigrams start.
     * Convenience for {@code getNGramOffset(3)}.
     *
     * @return the location of the trigrams
     */
    public long getTrigramOffset() {
        return getNGramOffset(3);
    }
    
    
    /**
     * Returns the location (or offset) into the file where NGrams start
     * at a specified N order. Only defined for orders above 1.
     *
     * @param n			the desired order, in [2, maxNGram]
     * @return the location of the n-grams of the given order
     */
    public long getNGramOffset(int n) {
        // Be sure that we don't go beyond the model's order.
        // FIX: use logical && (was bitwise &) for consistency with the other accessors.
        assert (n <= maxNGram) && (n > 1);
        return NGramOffset[n - 1];
    }


    /**
     * Returns the maximum depth (n-gram order) of the language model.
     *
     * @return the maximum depth of the language model
     */
    public int getMaxDepth() {
        return maxNGram;
    }


    /**
     * Returns true if the loaded file is in big-endian.
     *
     * @return true if the loaded file is big-endian
     */
    public boolean getBigEndian() {
        return bigEndian;
    }


    /**
     * Returns the number of bytes per n-gram field
     * (2 for 16-bit models, 4 for 32-bit models).
     *
     * @return the multiplier for the size of a NGram
     */
    public int getBytesPerField() {
        return bytesPerField;
    }
    
    
    /**
     * Loads {@code size} bytes from the underlying random-access file, starting
     * at the given position.
     *
     * @param position the starting position in the file
     * @param size     the number of bytes to load
     * @return the loaded bytes
     * @throws java.io.IOException if IO went wrong or the file ends early
     */
    public byte[] loadBuffer(long position, int size) throws IOException {
        file.seek(position);
        byte[] bytes = new byte[size];
        try {
            // FIX: RandomAccessFile.read(byte[]) may return fewer bytes than requested
            // without hitting EOF; readFully() loops until the buffer is filled.
            file.readFully(bytes);
        } catch (EOFException e) {
            throw new IOException("Incorrect number of bytes read. Size = " + size + ". Position =" + position + ".");
        }
        return bytes;
    }


    /**
     * Loads the language model layout from the given stream: header, unigrams,
     * probability/backoff/segment tables and the vocabulary. The n-gram records
     * themselves are skipped and later fetched on demand via {@link #loadBuffer}.
     *
     * @param inputStream stream to read the language model data
     * @throws java.io.IOException if IO went wrong
     */
    protected void loadModelLayout(InputStream inputStream) throws IOException {

        DataInputStream stream = new DataInputStream
                (new BufferedInputStream(inputStream));
        try {
            // read standard header string-size; set bigEndian flag
            readHeader(stream);

            // +1 is the sentinel unigram at the end
            unigrams = readUnigrams(stream, numberNGrams[0] + 1, bigEndian);

            // skip past the raw n-gram records, recording their file offsets
            skipNGrams(stream);

            // Read the NGram prob & bow tables (index i = order i+1)
            for (int i = 1; i < maxNGram; i++) {
                if (numberNGrams[i] > 0) {
                    if (i == 1) {
                        // bigrams carry only a probability table
                        NGramProbTable[i] = readFloatTable(stream, bigEndian);
                    } else {
                        NGramBackoffTable[i] = readFloatTable(stream, bigEndian);
                        NGramProbTable[i] = readFloatTable(stream, bigEndian);

                        int nMinus1gramSegmentSize = 1 << logNGramSegmentSize;
                        int NGramSegTableSize = ((numberNGrams[i - 1] + 1) / nMinus1gramSegmentSize) + 1;

                        NGramSegmentTable[i] = readIntTable(stream, bigEndian, NGramSegTableSize);
                    }
                }
            }

            // read word string names
            int wordsStringLength = readInt(stream, bigEndian);
            if (wordsStringLength <= 0) {
                throw new Error("Bad word string size: " + wordsStringLength);
            }

            // read the string of all words
            this.words = readWords(stream, wordsStringLength, numberNGrams[0]);

            // Make <s> unusable as a prediction and </s> unusable as a context.
            if (startWordID > -1) {
                UnigramProbability unigram = unigrams[startWordID];
                unigram.setLogProbability(MIN_PROBABILITY);
            }
            if (endWordID > -1) {
                UnigramProbability unigram = unigrams[endWordID];
                unigram.setLogBackoff(MIN_PROBABILITY);
            }

            applyUnigramWeight();

            if (applyLanguageWeightAndWip) {
                // FIX: the loop ran from 0 to maxNGram inclusive, indexing
                // NGramProbTable[maxNGram] (out of bounds) and NGramProbTable[0]
                // (always null: unigrams are weighted in applyUnigramWeight()).
                // Iterate only over the orders actually read above, and guard
                // against tables left null when numberNGrams[i] == 0.
                for (int i = 1; i < maxNGram; i++) {
                    if (NGramProbTable[i] != null) {
                        applyLanguageWeight(NGramProbTable[i], languageWeight);
                        applyWip(NGramProbTable[i], wip);
                    }
                    if (i > 1 && NGramBackoffTable[i] != null) {
                        applyLanguageWeight(NGramBackoffTable[i], languageWeight);
                    }
                }
            }
        } finally {
            // FIX: close the stream even when parsing throws
            stream.close();
        }
    }


    /**
     * Reads and validates the LM file header: determines the byte order, the
     * maximum n-gram order, the field width (16 or 32 bits), the log segment
     * size and the per-order n-gram counts.
     *
     * @param stream the data stream of the LM file
     * @throws java.io.IOException if reading from the stream fails
     */
    private void readHeader(DataInputStream stream) throws IOException {
        int headerLength = readInt(stream, bigEndian);

        // The first int is the header string length, including the trailing '\0'.
        // If it matches none of the known headers, retry with swapped byte order.
        if ((headerLength != DARPA_TG_HEADER.length() + 1) && (headerLength != DARPA_QG_HEADER.length() + 1) && (headerLength != DARPA_NG_HEADER.length() - 1)) { // not big-endian
            headerLength = Utilities.swapInteger(headerLength);
            if (headerLength == (DARPA_TG_HEADER.length() + 1) || headerLength == (DARPA_QG_HEADER.length() + 1) || headerLength == (DARPA_NG_HEADER.length() - 1)) {
                bigEndian = false;
            } else {
                throw new Error("Bad binary LM file magic number: "
                        + headerLength + ", not an LM dumpfile?");
            }
        }

        // read and verify standard header string
        String header = readString(stream, headerLength - 1);
        stream.readByte(); // read the '\0'
        bytesRead++;

        if (!header.equals(DARPA_TG_HEADER) && !header.equals(DARPA_QG_HEADER) && !Pattern.matches(DARPA_NG_HEADER, header)) {
            throw new Error("Bad binary LM file header: " + header);
        } else if (header.equals(DARPA_TG_HEADER)) {
            maxNGram = 3;
        } else if (header.equals(DARPA_QG_HEADER)) {
            maxNGram = 4;
        } else {
            // Generic "Darpa <N>-gram LM" header: extract the order digit.
            Matcher m = Pattern.compile("\\d").matcher(header);
            // FIX: Matcher.group() throws IllegalStateException unless find()
            // succeeded first; the original called group() without find().
            if (!m.find()) {
                throw new Error("Bad binary LM file header: " + header);
            }
            maxNGram = Integer.parseInt(m.group());
        }

        // read LM filename string size and string (the name itself is unused)
        int fileNameLength = readInt(stream, bigEndian);
        skipStreamBytes(stream, fileNameLength);

        numberNGrams = new int[maxNGram];
        NGramOffset = new long[maxNGram];
        NGramProbTable = new float[maxNGram][];
        NGramBackoffTable = new float[maxNGram][];
        NGramSegmentTable = new int[maxNGram][];

        numberNGrams[0] = 0;
        logNGramSegmentSize = LOG2_NGRAM_SEGMENT_SIZE;

        // read version number, if present. it must be <= 0.
        int version = readInt(stream, bigEndian);

        bytesPerField = 2;

        if (version <= 0) { // yes, it's the version number
            readInt(stream, bigEndian); // read and skip timestamp

            // version <= -3 means fields are stored on 32 bits
            if (version <= -3)
                bytesPerField = 4;

            // read and skip format description (zero-terminated list of strings)
            int formatLength;
            for (; ;) {
                if ((formatLength = readInt(stream, bigEndian)) == 0) {
                    break;
                }
                bytesRead += stream.skipBytes(formatLength);
            }

            // read log NGram segment size if present (16-bit version 2 LM only)
            if (version == -2) {
                logNGramSegmentSize = readInt(stream, bigEndian);
                if (logNGramSegmentSize < 1 || logNGramSegmentSize > 15) {
                    throw new Error("log2(bg_seg_sz) outside range 1..15");
                }
            }

            numberNGrams[0] = readInt(stream, bigEndian);
        } else {
            // no version field: the int we just read is the unigram count itself
            numberNGrams[0] = version;
        }

        if (numberNGrams[0] <= 0) {
            throw new Error("Bad number of unigrams: " + numberNGrams[0]
                    + ", must be > 0.");
        }

        // counts for orders 2..maxNGram
        for (int i = 1; i < maxNGram; i++) {
            if ((numberNGrams[i] = readInt(stream, bigEndian)) < 0) {
                throw new Error("Bad number of " + String.valueOf(i) + "-grams: " + numberNGrams[i]);
            }
        }
    }


    /**
     * Skips past all the raw n-gram records in the stream, recording in
     * {@code NGramOffset} the file offset at which each order's records start.
     *
     * @param stream
     *            the source of data
     * @throws java.io.IOException if the skip fails or the stream ends early
     */
    private void skipNGrams(DataInputStream stream) throws IOException {
        long bytesToSkip;

        // bigrams (+1 sentinel record at the end)
        // NOTE(review): for a pure bigram model (maxNGram == 2) this still uses
        // BYTES_PER_NGRAM rather than BYTES_PER_NMAXGRAM — confirm intended.
        NGramOffset[1] = bytesRead;
        // FIX: compute in long arithmetic; the all-int product overflowed for large models,
        // and was inconsistent with the casts used in the loop below.
        bytesToSkip = (long) (numberNGrams[1] + 1) * LargeNGramModel.BYTES_PER_NGRAM * getBytesPerField();
        skipStreamBytes(stream, bytesToSkip);

        for (int i = 2; i < maxNGram; i++) {
            if (numberNGrams[i] > 0 && i < maxNGram - 1) {
                // intermediate orders carry a sentinel record, like bigrams
                NGramOffset[i] = bytesRead;
                bytesToSkip = (long) (numberNGrams[i] + 1) * (long) LargeNGramModel.BYTES_PER_NGRAM * getBytesPerField();
                skipStreamBytes(stream, bytesToSkip);
            } else if (numberNGrams[i] > 0 && i == maxNGram - 1) {
                // the highest order uses smaller records and no sentinel
                NGramOffset[i] = bytesRead;
                bytesToSkip = (long) (numberNGrams[i]) * (long) LargeNGramModel.BYTES_PER_NMAXGRAM * getBytesPerField();
                skipStreamBytes(stream, bytesToSkip);
            }
        }
    }
    
    /**
     * Reliable skip: keeps skipping until the requested number of bytes has
     * been consumed, updating {@code bytesRead} along the way.
     *
     * @param stream stream
     * @param bytes number of bytes
     * @throws java.io.IOException if the stream ends before all bytes are skipped
     */
    private void skipStreamBytes(DataInputStream stream, long bytes) throws IOException {
        while (bytes > 0) {
            long skipped = stream.skip(bytes);
            if (skipped <= 0) {
                // FIX: skip() may legitimately return 0 (and returns 0 at EOF),
                // which made the original loop spin forever. Force progress with
                // a blocking single-byte read; -1 means a truncated file.
                if (stream.read() < 0) {
                    throw new EOFException("Unexpected end of stream while skipping " + bytes + " bytes");
                }
                skipped = 1;
            }
            bytesRead += skipped;
            bytes -= skipped;
        }
    }
    

    /**
     * Apply the unigram weight to the set of unigrams: each unigram probability
     * is interpolated (in the linear domain, via addAsLinear) with the uniform
     * distribution, p = uw * p1 + (1 - uw) / numUnigrams, except for the
     * sentence-start word. Optionally also applies the language weight and the
     * word insertion penalty.
     */
    private void applyUnigramWeight() {
        float logUnigramWeight = logMath.linearToLog(unigramWeight);
        float logNotUnigramWeight = logMath.linearToLog(1.0f - unigramWeight);
        // log of the uniform probability 1/numUnigrams
        float logUniform = logMath.linearToLog(1.0f / (numberNGrams[0]));

        float logWip = logMath.linearToLog(wip);

        // (1 - uw) * uniform, in log domain
        float p2 = logUniform + logNotUnigramWeight;

        for (int i = 0; i < numberNGrams[0]; i++) {
            UnigramProbability unigram = unigrams[i];

            float p1 = unigram.getLogProbability();

            // <s> keeps its (floored) probability un-interpolated
            if (i != startWordID) {
                p1 += logUnigramWeight;
                p1 = logMath.addAsLinear(p1, p2);
            }

            if (applyLanguageWeightAndWip) {
                p1 = p1 * languageWeight + logWip;
                unigram.setLogBackoff(unigram.getLogBackoff() * languageWeight);
            }

            unigram.setLogProbability(p1);
        }
    }


    /**
     * Scales every log probability in the given array by the language weight,
     * in place.
     */
    private void applyLanguageWeight(float[] logProbabilities,
                                     float languageWeight) {
        final int count = logProbabilities.length;
        for (int index = 0; index < count; index++) {
            logProbabilities[index] *= languageWeight;
        }
    }


    /**
     * Adds the word insertion penalty (converted to the log domain) to every
     * log probability in the given array, in place.
     */
    private void applyWip(float[] logProbabilities, double wip) {
        final float logWip = logMath.linearToLog(wip);
        final int count = logProbabilities.length;
        for (int index = 0; index < count; index++) {
            logProbabilities[index] += logWip;
        }
    }


    /**
     * Reads a probability table from the given DataInputStream, converting each
     * value from log base 10 into the LogMath log base.
     *
     * @param stream    the DataInputStream from which to read the table
     * @param bigEndian true if the given stream is bigEndian, false otherwise
     * @return the probability table, in LogMath log base
     * @throws java.io.IOException if reading from the stream fails
     */
    private float[] readFloatTable(DataInputStream stream, boolean bigEndian)
            throws IOException {

        int numProbs = readInt(stream, bigEndian);
        if (numProbs <= 0 || numProbs > MAX_PROB_TABLE_SIZE) {
            throw new Error("Bad probabilities table size: " + numProbs);
        }

        float[] probTable = new float[numProbs];

        for (int i = 0; i < numProbs; i++) {
            // stored in log10; convert to the LogMath base
            probTable[i] = logMath.log10ToLog(readFloat(stream, bigEndian));
        }

        return probTable;
    }


    /**
     * Reads a table of integers from the given DataInputStream, verifying that
     * its size matches the expected segment-table size.
     *
     * @param stream    the DataInputStream from which to read the table
     * @param bigEndian true if the given stream is bigEndian, false otherwise
     * @param tableSize the size of the NGram segment table
     * @return the NGram segment table, which is an array of integers
     * @throws java.io.IOException if reading from the stream fails
     */
    private int[] readIntTable(DataInputStream stream, boolean bigEndian,
                               int tableSize) throws IOException {
        final int numSegments = readInt(stream, bigEndian);
        if (numSegments != tableSize) {
            throw new Error("Bad NGram seg table size: " + numSegments);
        }
        final int[] segmentTable = new int[numSegments];
        for (int index = 0; index < numSegments; index++) {
            segmentTable[index] = readInt(stream, bigEndian);
        }
        return segmentTable;
    }


    /**
     * Read in the unigrams in the given DataInputStream. Each record holds the
     * unigram ID, its probability and backoff weight (both converted from log10
     * to the LogMath base) and the index of its first bigram entry.
     *
     * @param stream         the DataInputStream to read from
     * @param numberUnigrams the number of unigrams to read (including the sentinel)
     * @param bigEndian      true if the DataInputStream is big-endian, false otherwise
     * @return an array of UnigramProbability index by the unigram ID
     * @throws java.io.IOException if reading from the stream fails
     */
    private UnigramProbability[] readUnigrams(DataInputStream stream,
                                              int numberUnigrams, boolean bigEndian) throws IOException {

        UnigramProbability[] unigrams = new UnigramProbability[numberUnigrams];

        for (int i = 0; i < numberUnigrams; i++) {

            // read unigram ID, unigram probability, unigram backoff weight
            int unigramID = readInt(stream, bigEndian);
            
            /* Some tools to convert to DMP doesn't store ID's in unigrams,
               so fall back to the record position */
            if (unigramID < 1) 
            	unigramID = i;

            // if we're not reading the sentinel unigram at the end,
            // make sure that the unigram IDs are consecutive
            if (i != (numberUnigrams - 1)) {
                assert (unigramID == i);
            }

            float unigramProbability = readFloat(stream, bigEndian);
            float unigramBackoff = readFloat(stream, bigEndian);
            int firstBigramEntry = readInt(stream, bigEndian);

            // convert from log10 to the LogMath base
            float logProbability = logMath.log10ToLog(unigramProbability);
            float logBackoff = logMath.log10ToLog(unigramBackoff);

            unigrams[i] = new UnigramProbability(unigramID, logProbability,
                    logBackoff, firstBigramEntry);
        }

        return unigrams;
    }


    /**
     * Reads a 4-byte integer from the given DataInputStream, honoring the
     * requested byte order and advancing the byte counter.
     *
     * @param stream    the DataInputStream to read from
     * @param bigEndian true if the DataInputStream is in bigEndian, false otherwise
     * @return the integer read
     * @throws java.io.IOException if reading from the stream fails
     */
    private int readInt(DataInputStream stream, boolean bigEndian)
            throws IOException {
        bytesRead += 4;
        return bigEndian
                ? stream.readInt()
                : Utilities.readLittleEndianInt(stream);
    }


    /**
     * Reads a 4-byte float from the given DataInputStream, honoring the
     * requested byte order and advancing the byte counter.
     *
     * @param stream    the DataInputStream to read from
     * @param bigEndian true if the DataInputStream is in bigEndian, false otherwise
     * @return the float read
     * @throws java.io.IOException if reading from the stream fails
     */
    private float readFloat(DataInputStream stream, boolean bigEndian)
            throws IOException {
        bytesRead += 4;
        return bigEndian
                ? stream.readFloat()
                : Utilities.readLittleEndianFloat(stream);
    }



     /**
      * Reads a string of the given length from the given DataInputStream. It is assumed that the DataInputStream
      * contains 8-bit chars.
      *
      * @param stream the DataInputStream to read from
      * @param length the number of characters in the returned string
      * @return a string of the given length from the given DataInputStream
      * @throws java.io.IOException if reading from the stream fails or it ends early
      */
     private String readString(DataInputStream stream, int length)
             throws IOException {
         byte[] bytes = new byte[length];
         // FIX: read(byte[]) may return fewer bytes than requested;
         // readFully() loops until the buffer is filled (or throws EOFException)
         stream.readFully(bytes);
         bytesRead += length;
 
         StringBuilder builder = new StringBuilder(length);
         for (int i = 0; i < length; i++) {
             // FIX: mask to avoid sign extension for bytes >= 0x80,
             // consistent with the decoding done in readWords()
             builder.append((char) (bytes[i] & 0xFF));
         }
         return builder.toString();
     }
 
 
     /**
      * Reads a series of consecutive '\0'-terminated Strings from the given stream
      * and records the IDs of the sentence start/end words when encountered.
      *
      * @param stream         the DataInputStream to read from
      * @param length         the total length in bytes of all the Strings
      * @param numberUnigrams the number of String to read
      * @return an array of the Strings read
      * @throws java.io.IOException if reading from the stream fails or it ends early
      */
     private String[] readWords(DataInputStream stream, int length,
                                      int numberUnigrams) throws IOException {
         String[] words = new String[numberUnigrams];
         byte[] bytes = new byte[length];
         // FIX: read(byte[]) may return fewer bytes than requested; readFully()
         // loops until the buffer is filled (or throws EOFException).
         stream.readFully(bytes);
         // FIX: count the bytes exactly once; the original added the read() return
         // value AND incremented bytesRead per byte in the loop, double-counting.
         bytesRead += length;
 
         int s = 0;
         int wordStart = 0;
         for (int i = 0; i < length; i++) {
             char c = (char) (bytes[i] & 0xFF);
             if (c == '\0') {
                 // if its the end of a string, add it to the 'words' array
                 words[s] = new String(bytes, wordStart, i - wordStart);
                 wordStart = i + 1;
                 if (words[s].equals(Dictionary.SENTENCE_START_SPELLING)) {
                     startWordID = s;
                 } else if (words[s].equals(Dictionary.SENTENCE_END_SPELLING)) {
                     endWordID = s;
                 }
                 s++;
             }
         }
         assert (s == numberUnigrams);
         return words;
     }
     
}
